/// # Examples
/// ```rust
/// use hicollections::{ List, ListNode, list };
/// use core::ptr;
///
/// struct Bar {
///     node: ListNode,
/// }
/// struct Foo {
///     val: i32,
///     node1: ListNode,
///     node2: ListNode,
///     bar: Bar,
/// }
///
/// let foo = Foo {
///     val: 1,
///     node1: ListNode::new(),
///     node2: ListNode::new(),
///     bar: Bar { node: ListNode::new() },
/// };
///
/// let mut list = list!(Foo, bar.node);
/// unsafe { list.add_tail(&foo); }
///
/// list.iter_mut().fold(1, |n, foo| {
///     assert_eq!(n, foo.val);
///     foo.val += 1;
///     n + 1
/// });
/// assert_eq!(foo.val, 2);
///
/// let mut list = list!();
/// unsafe { list.add_tail(&foo.node2); }
/// assert_eq!(list.iter().count(), 1);
/// ```

use super::Link;
use core::cell::UnsafeCell;
use core::marker::{PhantomData, PhantomPinned};
use core::ops::{Deref, DerefMut};
use core::ptr::{self, NonNull};

/// Convenience constructor for a `List`.
///
/// * `list!()` builds a `List<ListNode>` whose elements are bare `ListNode`s.
/// * `list!(Type, field.path)` builds a `List<Type>` whose link node is the
///   `ListNode` reached through the (possibly nested) field path `field.path`.
///
/// ```rust
/// use hicollections::{list, List, ListNode};
/// let list: List<ListNode> = list!();
/// struct Foo {
///     val: i32,
///     node: ListNode,
/// }
/// let list: List<Foo> = list!(Foo, node);
/// ```
#[macro_export]
macro_rules! list {
    ($type: path, $($mem: ident).+ ) => {
        $crate::List::<$type>::new(|node| unsafe {
            // `addr_of!` yields the field address without creating a reference.
            ::core::ptr::addr_of!((*node).$($mem).+)
        })
    };
    () => {
        // Fully qualified so callers do not need `ListNode` in scope, and a
        // caller-side type that happens to be named `ListNode` cannot hijack it.
        $crate::List::<$crate::ListNode>::new(|node| node)
    };
}

/// Intrusive, circular, doubly-linked list of `T` values that embed a
/// [`ListNode`].
///
/// In the C version of this design, the list head is itself a node in the
/// ring. That does not work in Rust. A `List` is moved around all the time; a
/// move is a memory copy the programmer cannot intercept, and after such a copy
/// the ring would be corrupted.
///
/// Nodes have the same problem: once linked, they must not be moved, which is
/// why adding and removing nodes is `unsafe`. The `List`, however, only manages
/// nodes, and preventing the `List` itself from ever being moved would be very
/// hard.
///
/// Therefore the `List` is *not* part of the ring. It only remembers the head
/// node, so moving the `List` does not affect the ring, and the `List` can
/// naturally be `Send`.
///
/// The trade-off is some inconvenience: a node cannot remove itself and must be
/// removed through its `List`. In exchange, memory safety is preserved. For
/// example, removing a node cannot invalidate an in-progress iterator, because
/// `del` needs `&mut List` while iterators borrow the list.
#[repr(C)]
pub struct List<T> {
    // First node of the ring, or `None` when the list is empty.
    head: Option<ListNodeRef>,
    // Byte offset of the linking `ListNode` inside `T` (computed in `List::new`).
    offset: usize,
    // Ties the list to `T` without storing one; `*const T` also makes `List`
    // neither `Send` nor `Sync` by default (see the explicit `unsafe impl`s).
    mark: PhantomData<*const T>,
}

// SAFETY: a `List<T>` hands out `&T` (`iter`, `first`, `last`, ...) and, through
// `&mut List<T>`, `&mut T` (`iter_mut`, ...) for elements it does not own.
// Moving the list to another thread lets that thread obtain both kinds of
// reference while the elements may still be shared with the original thread,
// so `T: Sync` (for `&T`) and `T: Send` (for `&mut T`) are both required.
unsafe impl<T: Send + Sync> Send for List<T> {}
// SAFETY: through a shared `&List<T>` only `&T` can be obtained (all mutation of
// the links requires `&mut List<T>`), so `T: Sync` is sufficient.
unsafe impl<T: Sync> Sync for List<T> {}

/// Link node. A user data structure must contain a field of this type to be
/// added to a [`List`].
///
/// The links live inside an `UnsafeCell`, so they are updated through a shared
/// `&ListNode`.
#[repr(transparent)]
pub struct ListNode(UnsafeCell<Inner>);

/// Link storage wrapped by [`ListNode`].
#[repr(C)]
struct Inner {
    // Previous node in the ring; `None` while the node is unlinked.
    prev: Option<ListNodeRef>,
    // Next node in the ring; `None` while unlinked (`ListNode::linked` checks this).
    next: Option<ListNodeRef>,
    // Opts `ListNode` out of `Unpin`, signalling that a linked node must not move.
    pin: PhantomPinned,
    // Raw-pointer marker: keeps `ListNode` from being `Send`/`Sync` automatically.
    mark: PhantomData<*mut ListNode>,
}

/// Non-owning, copyable pointer to a [`ListNode`].
///
/// It does not track the node's lifetime; its validity relies on the `unsafe`
/// contracts of the `List` insertion methods.
#[derive(Copy, Clone)]
struct ListNodeRef {
    ptr: NonNull<ListNode>,
}

impl ListNodeRef {
    fn as_ptr(&self) -> *mut ListNode {
        self.ptr.as_ptr()
    }
}

impl From<&'_ ListNode> for ListNodeRef {
    fn from(src: &'_ ListNode) -> Self {
        Self {
            ptr: NonNull::from(src),
        }
    }
}

impl Deref for ListNodeRef {
    type Target = ListNode;

    /// Borrows the referenced node.
    fn deref(&self) -> &ListNode {
        let raw = self.ptr.as_ptr();
        // SAFETY: `ListNodeRef`s are built from a `&ListNode` (see the `From`
        // impl), and the `unsafe` insertion contracts of `List` require a linked
        // node to stay alive and in place until it is removed.
        unsafe { &*raw }
    }
}

impl DerefMut for ListNodeRef {
    fn deref_mut(&mut self) -> &mut Self::Target {
        unsafe { self.ptr.as_mut() }
    }
}

impl<T> List<T> {
    /// Creates an empty list.
    ///
    /// `f` maps a pointer to a `T` to a pointer to the `ListNode` embedded in it.
    /// `crate::node_offset` derives from it the byte offset that is later used to
    /// convert between `&T` and its `ListNode` (see the `list!` macro).
    ///
    /// Why not a trait converting between `T` and `ListNode`? A single `T` may be
    /// linked into several lists through different `ListNode` fields. With a
    /// trait, each such list would need its own wrapper type implementing the
    /// trait. Here each list just supplies its own closure, which is simpler to use.
    pub fn new<F>(f: F) -> Self
    where
        F: Fn(*const T) -> *const ListNode,
    {
        Self {
            head: None,
            offset: crate::node_offset(f),
            mark: PhantomData,
        }
    }

    /// Appends `node` at the tail of the list.
    ///
    /// # Safety
    /// The caller must ensure that:
    /// * `node` is not linked into any list;
    /// * `node` is removed from this list before its lifetime ends;
    /// * `node` is not moved (no ownership transfer) until it has been removed.
    pub unsafe fn add_tail<'a>(&mut self, node: &'a T) -> Link<'a> {
        let node = self.node_from(node);
        debug_assert!(!node.linked());
        self.link_at_tail(node);
        Link::new()
    }

    /// Inserts `node` at the head of the list.
    ///
    /// # Safety
    /// Same contract as [`List::add_tail`].
    pub unsafe fn add_head<'a>(&mut self, node: &'a T) -> Link<'a> {
        let node = self.node_from(node);
        debug_assert!(!node.linked());
        self.link_at_tail(node);
        // The ring is circular: the node just linked in front of the old head
        // becomes the first element by moving `head` onto it.
        self.head = Some(node);
        Link::new()
    }

    /// Inserts `node` immediately before `pos`. If `pos` was the head, `node`
    /// becomes the new head.
    ///
    /// # Safety
    /// The caller must ensure that:
    /// * `pos` is linked into this list;
    /// * `node` is not linked into any list;
    /// * `node` is removed before its lifetime ends;
    /// * `node` is not moved until it has been removed.
    ///
    /// # Panics
    /// Panics if the list is empty.
    pub unsafe fn add<'a>(&mut self, node: &'a T, pos: &T) -> Link<'a> {
        let node = self.node_from(node);
        let pos = self.node_from(pos);
        debug_assert!(!node.linked());
        debug_assert!(pos.linked());
        Self::link_before(node, pos);
        if ptr::eq(self.head.unwrap().as_ptr(), pos.as_ptr()) {
            self.head = Some(node);
        }
        Link::new()
    }

    /// Unlinks `node` from the list. Does nothing if `node` is not linked.
    ///
    /// # Safety
    /// The caller must ensure `node` is either linked into *this* list or not
    /// linked into any list.
    pub unsafe fn del(&mut self, node: &T) {
        let node = self.node_from(node);
        // An unlinked node has no `next`, so there is nothing to do.
        if let Some(next) = node.next() {
            let prev = node.prev().unwrap();
            prev.link_next(next);
            node.init();
            if ptr::eq(self.head.unwrap().as_ptr(), node.as_ptr()) {
                // `next == node` means `node` was the only element.
                self.head = if ptr::eq(node.as_ptr(), next.as_ptr()) {
                    None
                } else {
                    Some(next)
                };
            }
        }
    }

    /// Returns the first element, or `None` if the list is empty.
    ///
    /// The lifetime `'a` of the result is not tied to `self`. The caller must not
    /// use the reference after the element stops being valid.
    pub fn first<'a>(&self) -> Option<&'a T> {
        self.head.map(|node| self.obj_from(node))
    }

    /// Returns the last element, or `None` if the list is empty.
    ///
    /// As with [`List::first`], the lifetime `'a` is not tied to `self`.
    pub fn last<'a>(&self) -> Option<&'a T> {
        self.head.map(|node| self.obj_from(node.prev().unwrap()))
    }

    /// Returns `true` if the list has no elements.
    pub fn empty(&self) -> bool {
        self.head.is_none()
    }

    /// Returns `true` if the list has exactly one element.
    pub fn singular(&self) -> bool {
        self.head
            .map_or(false, |head| ptr::eq(head.as_ptr(), head.next().unwrap().as_ptr()))
    }

    /// Moves every element of `self` to the front of `other`, keeping their
    /// order. `self` is left empty. Does nothing if `self` is empty.
    ///
    /// # Panics
    /// Panics if `self` is non-empty and the two lists link through different
    /// `ListNode` fields (e.g. `list!(Foo, a)` vs `list!(Foo, b)`). Otherwise
    /// `other` would compute wrong element addresses for the moved nodes.
    pub fn move_head(&mut self, other: &mut Self) {
        if let Some(head) = self.head {
            self.assert_same_layout(other);
            if let Some(other_head) = other.head {
                Self::splice_rings(head, other_head);
            }
            other.head = Some(head);
            self.head = None;
        }
    }

    /// Moves every element of `self` to the back of `other`, keeping their
    /// order. `self` is left empty. Does nothing if `self` is empty.
    ///
    /// # Panics
    /// Same condition as [`List::move_head`].
    pub fn move_tail(&mut self, other: &mut Self) {
        if let Some(head) = self.head {
            self.assert_same_layout(other);
            match other.head {
                Some(other_head) => Self::splice_rings(head, other_head),
                None => other.head = Some(head),
            }
            self.head = None;
        }
    }

    /// Iterates over the elements from head to tail.
    pub fn iter(&self) -> Iter<'_, T> {
        Iter::new(self, self.head_node(), self.head_node(), &ListNode::next)
    }

    /// Iterates over the elements from tail to head.
    pub fn rev_iter(&self) -> Iter<'_, T> {
        Iter::new(self, self.tail_node(), self.tail_node(), &ListNode::prev)
    }

    /// Iterates from `pos` towards the tail, stopping before wrapping back to the
    /// head. `pos` is expected to be an element of this list.
    pub fn iter_from(&self, pos: &T) -> Iter<'_, T> {
        Iter::new(self, Some(self.node_from(pos)), self.head_node(), &ListNode::next)
    }

    /// Iterates from `pos` towards the head, stopping before wrapping back to the
    /// tail. `pos` is expected to be an element of this list.
    pub fn rev_iter_from(&self, pos: &T) -> Iter<'_, T> {
        Iter::new(self, Some(self.node_from(pos)), self.tail_node(), &ListNode::prev)
    }

    /// Like [`List::iter`], yielding `&mut T`.
    pub fn iter_mut(&mut self) -> IterMut<'_, T> {
        IterMut::new(self, self.head_node(), self.head_node(), &ListNode::next)
    }

    /// Like [`List::rev_iter`], yielding `&mut T`.
    pub fn rev_iter_mut(&mut self) -> IterMut<'_, T> {
        IterMut::new(self, self.tail_node(), self.tail_node(), &ListNode::prev)
    }

    /// Like [`List::iter_from`], yielding `&mut T`.
    pub fn iter_mut_from(&mut self, pos: &T) -> IterMut<'_, T> {
        IterMut::new(self, Some(self.node_from(pos)), self.head_node(), &ListNode::next)
    }

    /// Like [`List::rev_iter_from`], yielding `&mut T`.
    pub fn rev_iter_mut_from(&mut self, pos: &T) -> IterMut<'_, T> {
        IterMut::new(self, Some(self.node_from(pos)), self.tail_node(), &ListNode::prev)
    }

    /// Links the unlinked `node` in as the last element, or as the only element
    /// (and head) when the list is empty.
    fn link_at_tail(&mut self, node: ListNodeRef) {
        match self.head {
            Some(head) => Self::link_before(node, head),
            None => {
                // A lone node forms a ring with itself.
                node.link_next(node);
                self.head = Some(node);
            }
        }
    }

    /// Links the unlinked `node` into the ring immediately before `pos`.
    fn link_before(node: ListNodeRef, pos: ListNodeRef) {
        let prev = pos.prev().unwrap();
        prev.link_next(node);
        node.link_next(pos);
    }

    /// Joins two disjoint rings into one: the node before `a` is linked to `b`,
    /// and the node before `b` is linked to `a`.
    fn splice_rings(a: ListNodeRef, b: ListNodeRef) {
        let a_tail = a.prev().unwrap();
        let b_tail = b.prev().unwrap();
        b_tail.link_next(a);
        a_tail.link_next(b);
    }

    /// Panics unless `other` locates its `ListNode` at the same offset inside `T`.
    fn assert_same_layout(&self, other: &Self) {
        assert_eq!(
            self.offset, other.offset,
            "cannot move nodes between lists that link through different ListNode fields"
        );
    }

    /// Returns the `ListNode` embedded in `obj`, `self.offset` bytes into it.
    fn node_from(&self, obj: &T) -> ListNodeRef {
        // Pointer arithmetic (instead of a round-trip through `usize`) keeps the
        // pointer's provenance.
        let node = (obj as *const T)
            .cast::<u8>()
            .wrapping_add(self.offset)
            .cast::<ListNode>();
        // SAFETY: assumes the closure given to `new` pointed at a `ListNode` field
        // of `T`, so `node` addresses that field inside the live object `obj`.
        unsafe { &*node }.into()
    }

    /// Returns the element that embeds `node`, `self.offset` bytes before it.
    fn obj_from<'a>(&self, node: ListNodeRef) -> &'a T {
        let obj = node.as_ptr().cast::<u8>().wrapping_sub(self.offset).cast::<T>();
        // SAFETY: `node` was produced by `node_from` on a live `T` linked into
        // this list (per the `unsafe` insertion contracts).
        unsafe { &*obj }
    }

    /// Mutable counterpart of `obj_from`.
    fn obj_from_mut<'a>(&self, node: ListNodeRef) -> &'a mut T {
        let obj = node.as_ptr().cast::<u8>().wrapping_sub(self.offset).cast::<T>();
        // SAFETY: as for `obj_from`; exclusivity is the caller's responsibility.
        unsafe { &mut *obj }
    }

    /// First node of the ring, if any.
    fn head_node(&self) -> Option<ListNodeRef> {
        self.head
    }

    /// Last node of the ring (the one before the head), if any.
    fn tail_node(&self) -> Option<ListNodeRef> {
        self.head.and_then(|head| head.prev())
    }
}

/// Conceptually this should be an "unsafe Drop", which Rust does not currently
/// support. If the owning `List` is still in use, the `ListNode` must be removed
/// from it before being dropped.
///
/// The body is intentionally empty: nothing is unlinked automatically.
impl Drop for ListNode {
    fn drop(&mut self) {}
}

impl Default for ListNode {
    fn default() -> Self {
        ListNode::new()
    }
}

impl ListNode {
    pub const fn new() -> Self {
        Self(UnsafeCell::new(Inner {
            prev: None,
            next: None,
            pin: PhantomPinned,
            mark: PhantomData,
        }))
    }

    fn prev(&self) -> Option<ListNodeRef> {
        self.inner().prev
    }

    fn next(&self) -> Option<ListNodeRef> {
        self.inner().next
    }

    fn link_next(&self, next: ListNodeRef) {
        self.set_next(Some(next));
        next.set_prev(Some(self.into()));
    }

    fn set_next(&self, next: Option<ListNodeRef>) {
        let inner = unsafe { &mut *self.0.get() };
        inner.next = next;
    }

    fn set_prev(&self, prev: Option<ListNodeRef>) {
        let inner = unsafe { &mut *self.0.get() };
        inner.prev = prev;
    }

    fn init(&self) {
        let inner = unsafe { &mut *self.0.get() };
        inner.next = None;
        inner.prev = None;
    }

    fn linked(&self) -> bool {
        self.inner().next.is_some()
    }

    fn inner(&self) -> &Inner {
        unsafe { &*self.0.get() }
    }
}

/// Borrowing iterator over the elements of a [`List`], yielding `&T`.
///
/// Created by `List::iter`, `List::rev_iter`, `List::iter_from` and
/// `List::rev_iter_from`.
pub struct Iter<'a, T> {
    // List being iterated; converts nodes back into `&T`.
    list: &'a List<T>,
    // Node to yield next; `None` once exhausted.
    pos: Option<ListNodeRef>,
    // Iteration stops when the next step would land on this node.
    end: Option<ListNodeRef>,
    // Step function: `ListNode::next` (forward) or `ListNode::prev` (reverse).
    next: &'static dyn Fn(&ListNode) -> Option<ListNodeRef>,
}

impl<'a, T> Iter<'a, T> {
    fn new(
        list: &'a List<T>,
        pos: Option<ListNodeRef>,
        end: Option<ListNodeRef>,
        next: &'static dyn Fn(&ListNode) -> Option<ListNodeRef>,
    ) -> Self {
        Self {
            list,
            pos,
            end,
            next,
        }
    }
}

impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    /// Yields the current element and advances; the walk ends once the step
    /// would reach `end` or the current node has no neighbour.
    fn next(&mut self) -> Option<Self::Item> {
        let current = self.pos?;
        let end = self.end;
        self.pos = (self.next)(&current)
            .filter(|following| !ptr::eq(following.as_ptr(), end.unwrap().as_ptr()));
        Some(self.list.obj_from(current))
    }
}

/// Iterator over the elements of a [`List`], yielding `&mut T`.
///
/// Created by `List::iter_mut`, `List::rev_iter_mut`, `List::iter_mut_from`
/// and `List::rev_iter_mut_from`; it keeps the list mutably borrowed.
pub struct IterMut<'a, T> {
    // List being iterated; converts nodes back into `&mut T`.
    list: &'a mut List<T>,
    // Node to yield next; `None` once exhausted.
    pos: Option<ListNodeRef>,
    // Iteration stops when the next step would land on this node.
    end: Option<ListNodeRef>,
    // Step function: `ListNode::next` (forward) or `ListNode::prev` (reverse).
    next: &'static dyn Fn(&ListNode) -> Option<ListNodeRef>,
}

impl<'a, T> IterMut<'a, T> {
    fn new(
        list: &'a mut List<T>,
        pos: Option<ListNodeRef>,
        end: Option<ListNodeRef>,
        next: &'static dyn Fn(&ListNode) -> Option<ListNodeRef>,
    ) -> Self {
        Self {
            list,
            pos,
            end,
            next,
        }
    }
}

impl<'a, T> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;

    /// Yields the current element mutably and advances; the walk ends once the
    /// step would reach `end` or the current node has no neighbour.
    fn next(&mut self) -> Option<Self::Item> {
        let current = self.pos?;
        let end = self.end;
        self.pos = (self.next)(&current)
            .filter(|following| !ptr::eq(following.as_ptr(), end.unwrap().as_ptr()));
        Some(self.list.obj_from_mut(current))
    }
}

// Unit tests. Most tests link bare `ListNode`s via `list!()`, so elements and
// nodes are the same objects; `test_user_node` and `test_link` embed the node
// inside a user struct.
#[cfg(test)]
mod test {
    extern crate std;
    use crate::ListNode;
    use std::ptr;

    // add_tail / add_head / del together with the first/last/empty/singular queries.
    #[test]
    fn test_push() {
        struct Foo {
            node1: ListNode,
            node2: ListNode,
        }

        let mut list = list!();
        let foo = Foo {
            node1: ListNode::new(),
            node2: ListNode::new(),
        };

        assert!(list.first().is_none());
        assert!(list.last().is_none());
        assert!(list.empty());
        unsafe { list.add_tail(&foo.node1) };
        assert!(list.singular());
        assert!(ptr::eq(&foo.node1, list.first().unwrap()));
        assert!(ptr::eq(&foo.node1, list.last().unwrap()));
        unsafe { list.add_head(&foo.node2) };
        assert!(!list.singular());
        assert!(ptr::eq(&foo.node2, list.first().unwrap()));
        assert!(ptr::eq(&foo.node1, list.last().unwrap()));

        unsafe { list.del(&foo.node1); }
        assert!(list.singular());
        assert!(ptr::eq(&foo.node2, list.first().unwrap()));
        assert!(ptr::eq(&foo.node2, list.last().unwrap()));
        unsafe { list.del(&foo.node2) };
        assert!(list.empty());
        assert!(list.first().is_none());
        assert!(list.last().is_none());

        // Re-link both nodes. They are still linked when the test ends;
        // `ListNode::drop` does nothing, so no unlinking is attempted.
        unsafe {
            list.add_head(&foo.node1);
            list.add_tail(&foo.node2);
        }

        assert!(ptr::eq(&foo.node1, list.first().unwrap()));
        assert!(ptr::eq(&foo.node2, list.last().unwrap()));
    }

    // `add(&n2, &n1)` inserts n2 before n1; n1 was the head, so n2 becomes the head.
    #[test]
    fn test_add() {
        let n1 = ListNode::new();
        let n2 = ListNode::new();
        let mut list = list!();
        unsafe {
            list.add_tail(&n1);
            list.add(&n2, &n1);
        }
        assert!(ptr::eq(&n2, list.first().unwrap()));
        assert!(ptr::eq(&n1, list.last().unwrap()));
    }

    // Forward iteration yields head..tail and then stops (no wrap-around).
    #[test]
    fn test_iter() {
        let n1 = ListNode::new();
        let n2 = ListNode::new();
        let mut list = list!();

        unsafe {
            list.add_tail(&n1);
            list.add_tail(&n2);
        }

        let mut iter = list.iter();
        let next = iter.next();
        assert!(next.is_some());
        assert!(ptr::eq(next.unwrap(), &n1));
        let next = iter.next();
        assert!(next.is_some());
        assert!(ptr::eq(next.unwrap(), &n2));
        assert!(iter.next().is_none());
    }

    // Reverse iteration yields tail..head and then stops.
    #[test]
    fn test_rev_iter() {
        let n1 = ListNode::new();
        let n2 = ListNode::new();
        let mut list = list!();

        unsafe {
            list.add_tail(&n1);
            list.add_tail(&n2);
        }

        let mut iter = list.rev_iter();
        let next1 = iter.next().unwrap();
        let next2 = iter.next().unwrap();

        assert!(ptr::eq(next1, &n2));
        assert!(ptr::eq(next2, &n1));

        assert!(iter.next().is_none());
    }

    // A list of user structs linked through an embedded `node` field.
    #[test]
    fn test_user_node() {
        struct Foo {
            node: ListNode,
            val: i32,
        }

        fn new(val: i32) -> Foo {
            Foo {
                node: ListNode::new(),
                val,
            }
        }

        let mut list = list!(Foo, node);
        let foo1 = new(1);
        let foo2 = new(2);
        let foo3 = new(3);

        unsafe {
            list.add_tail(&foo1);
            list.add_tail(&foo2);
            list.add_tail(&foo3);
        }

        // Elements come back in insertion order: 1, 2, 3.
        let iter = list.iter();
        iter.fold(1, |n, node| {
            assert_eq!(n, node.val);
            n + 1
        });

        assert_eq!(list.iter().count(), 3);
    }

    // Resulting order (head first) is n2, n3, n1. iter_from stops before the
    // head; rev_iter_from stops before the tail.
    #[test]
    fn test_iter_from() {
        let n1 = ListNode::new();
        let n2 = ListNode::new();
        let n3 = ListNode::new();

        let mut list = list!();

        unsafe {
            list.add_tail(&n1);
            list.add(&n2, &n1);
            list.add(&n3, &n1);
        }

        let mut iter = list.iter_from(&n3);
        let next1 = iter.next().unwrap();
        let next2 = iter.next().unwrap();
        assert!(ptr::eq(next1, &n3));
        assert!(ptr::eq(next2, &n1));
        assert!(iter.next().is_none());

        let mut iter = list.rev_iter_from(&n3);
        let next1 = iter.next().unwrap();
        let next2 = iter.next().unwrap();
        assert!(ptr::eq(next1, &n3));
        assert!(ptr::eq(next2, &n2));
        assert!(iter.next().is_none());
    }

    // Three add_head calls give n3, n2, n1. move_head transfers all of them to
    // the empty `list2` in the same order and leaves `list` empty.
    #[test]
    fn test_move() {
        let n1 = ListNode::new();
        let n2 = ListNode::new();
        let n3 = ListNode::new();

        let mut list = list!();

        unsafe {
            list.add_head(&n1);
            list.add_head(&n2);
            list.add_head(&n3);
        }

        let mut list2 = list!();

        list.move_head(&mut list2);
        let mut iter = list2.iter();
        assert!(ptr::eq(iter.next().unwrap(), &n3));
        assert!(ptr::eq(iter.next().unwrap(), &n2));
        assert!(ptr::eq(iter.next().unwrap(), &n1));
        assert!(iter.next().is_none());

        let mut iter = list.iter();
        assert!(iter.next().is_none());
    }

    // The `Link` returned by add_tail keeps `foo` borrowed. Presumably it holds
    // the `'a` borrow of the `&'a T` passed in, so `foo` cannot be mutated until
    // the link is dropped.
    #[test]
    fn test_link() {
        struct Foo {
            val: i32,
            node: ListNode,
        }

        let mut foo = Foo {
            val: 1,
            node: ListNode::new(),
        };
        let mut list = list!(Foo, node);

        let link = unsafe { list.add_tail(&foo) };

        // foo.val = 100; // would not compile while `link` is alive

        let _bar = &foo; // OK: shared borrows are still allowed
        // let foo = &mut foo; // ERROR: `foo` is still borrowed by `link`

        link.drop();

        // After the link is dropped, `foo` can be mutably borrowed again.
        let foo = &mut foo;
        foo.val = 1;
        unsafe { list.del(foo); }
    }
}
